{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# 使用 tf.data API 获得更高性能"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/data_performance\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/data_performance.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/data_performance.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 中查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/data_performance.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## 概述\n",
        "\n",
        "GPU 和 TPU 能够极大缩短执行单个训练步骤所需的时间。为了达到最佳性能，需要高效的输入流水线，以在当前步骤完成之前为下一步提供数据。`tf.data` API 有助于构建灵活高效的输入流水线。本文档演示了如何使用 `tf.data` API 构建高性能的 TensorFlow 输入流水线。\n",
        "\n",
        "继续之前，请阅读“[构建 TensorFlow 输入流水线](https://render.githubusercontent.com/view/data.ipynb)”指南，了解如何使用 `tf.data` API。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UhNtHfuxCGVy"
      },
      "source": [
        "## 资源\n",
        "\n",
        "- [构建 TensorFlow 输入流水线](./data.ipynb)\n",
        "- `tf.data.Dataset` API\n",
        "- 使用 TF Profiler 分析<code>tf.data</code> 性能"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IqR2PQG4ZaZ0"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import time"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QthTHCKF-jKD"
      },
      "source": [
        "在本指南中，您将迭代数据集并衡量性能。制定可重现的性能基准可能很困难，不同的因素会对其产生影响：\n",
        "\n",
        "- 当前的 CPU 负载\n",
        "- 网络流量\n",
        "- 缓存等复杂机制\n",
        "\n",
        "因此，要提供可重现的基准，需要构建人工样本。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3bU5gsSI-jKF"
      },
      "source": [
        "### 数据集\n",
        "\n",
        "定义一个继承自 `tf.data.Dataset` 的类，称为 `ArtificialDataset`。此数据集：\n",
        "\n",
        "- 会生成 `num_samples` 样本（默认数量为 3）\n",
        "- 会在第一项模拟打开文件之前休眠一段时间\n",
        "- 会在生成每一个模拟从文件读取数据的项之前休眠一段时间"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zUQv4kCd-jKH"
      },
      "outputs": [],
      "source": [
        "class ArtificialDataset(tf.data.Dataset):\n",
        "    def _generator(num_samples):\n",
        "        # Opening the file\n",
        "        time.sleep(0.03)\n",
        "        \n",
        "        for sample_idx in range(num_samples):\n",
        "            # Reading data (line, record) from the file\n",
        "            time.sleep(0.015)\n",
        "            \n",
        "            yield (sample_idx,)\n",
        "    \n",
        "    def __new__(cls, num_samples=3):\n",
        "        return tf.data.Dataset.from_generator(\n",
        "            cls._generator,\n",
        "            output_types=tf.dtypes.int64,\n",
        "            output_shapes=(1,),\n",
        "            args=(num_samples,)\n",
        "        )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O9y1WjNv-jKL"
      },
      "source": [
        "此数据集类似 `tf.data.Dataset.range`，在开头和每个样本之间添加了固定延迟。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FGK1Y4jn-jKM"
      },
      "source": [
        "### 训练循环\n",
        "\n",
        "编写一个虚拟的训练循环，以测量迭代数据集所用的时间。训练时间是模拟的。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MIaM3u00-jKP"
      },
      "outputs": [],
      "source": [
        "def benchmark(dataset, num_epochs=2):\n",
        "    start_time = time.perf_counter()\n",
        "    for epoch_num in range(num_epochs):\n",
        "        for sample in dataset:\n",
        "            # Performing a training step\n",
        "            time.sleep(0.01)\n",
        "    tf.print(\"Execution time:\", time.perf_counter() - start_time)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KK58SuXS-jKT"
      },
      "source": [
        "## 优化性能\n",
        "\n",
        "为了展示如何优化性能，下面我们将优化 `ArtificialDataset` 的性能。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xi8t26y7-jKV"
      },
      "source": [
        "### 朴素的方法\n",
        "\n",
        "先从不使用任何技巧的朴素流水线开始，按原样迭代数据集。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_gP7J1y4-jKY"
      },
      "outputs": [],
      "source": [
        "benchmark(ArtificialDataset())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lxeat5dH-jKf"
      },
      "source": [
        "从后台可以看到执行时间的花费情况：\n",
        "\n",
        "![Naive](https://tensorflow.google.cn/guide/images/data_performance/naive.svg)\n",
        "\n",
        "可以看到，执行一个训练步骤涉及以下操作：\n",
        "\n",
        "- 打开文件（如果尚未打开）\n",
        "- 从文件获取数据条目\n",
        "- 使用数据进行训练\n",
        "\n",
        "但是，在类似这里的朴素同步实现中，当流水线在获取数据时，模型会处于空闲状态。相反，当模型在进行训练时，输入流水线会处于空闲状态。因此，训练步骤的用时是打开、读取和训练时间的总和。\n",
        "\n",
        "接下来的各部分将基于此输入流水线，演示设计高效 TensorFlow 输入流水线的最佳做法。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mfukBGNz-jKh"
      },
      "source": [
        "### 预提取\n",
        "\n",
        "预提取会与训练步骤的预处理和模型执行重叠。在模型执行第 `s` 步训练的同时，输入流水线会读取第 `s+1` 步的数据。这样做能够最大程度减少训练的单步用时（而非总和），并减少提取数据所需的时间。\n",
        "\n",
        "`tf.data` API 提供了 `tf.data.Dataset.prefetch` 转换。它可以用来将数据的生成时间和使用时间分离。特别是，该转换会使用后台线程和内部缓冲区在请求元素之前从输入数据集中预提取这些元素。要预提取的元素数量应等于（或可能大于）单个训练步骤使用的批次数。您可以手动调整这个值，或者将其设置为 `tf.data.experimental.AUTOTUNE`，这将提示 `tf.data` 运行时在运行期间动态调整该值。\n",
        "\n",
        "注意，只要有机会将“生产者”的工作和“使用者”的工作重叠，预提取转换就能带来好处。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DHpUVqH1-jKi"
      },
      "outputs": [],
      "source": [
        "benchmark(\n",
        "    ArtificialDataset()\n",
        "    .prefetch(tf.data.experimental.AUTOTUNE)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h7z_kzo--jKn"
      },
      "source": [
        "![Prefetched](https://tensorflow.google.cn/guide/images/data_performance/prefetched.svg)\n",
        "\n",
        "这次您可以看到，在针对样本 0 运行训练步骤的同时，输入流水线正在读取样本 1 的数据，依此类推。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "52QMKfaY-jKq"
      },
      "source": [
        "### 并行数据提取\n",
        "\n",
        "在实际设置中，输入数据可能会远程存储（例如，GCS 或 HDFS）。由于在本地存储空间和远程存储空间之间存在以下差异，在本地读取数据时运行良好的数据集流水线可能会在远程读取数据时成为 I/O 瓶颈：\n",
        "\n",
        "- **到达第一字节用时**：从远程存储空间中读取文件的第一个字节所花费的时间要比从本地存储空间中读取所花费的时间长几个数量级。\n",
        "- **读取吞吐量**：虽然远程存储空间通常提供较大的聚合带宽，但读取单个文件可能只能使用此带宽的一小部分。\n",
        "\n",
        "此外，将原始字节加载到内存中后，可能还需要对数据进行反序列化和/或解密（例如 [protobuf](https://developers.google.com/protocol-buffers/)），这需要进行额外计算。无论数据是本地存储还是远程存储，都会存在此开销，但如果没有高效地预提取数据，在远程情况下会更糟。\n",
        "\n",
        "为了减轻各种数据提取开销的影响，可以使用 `tf.data.Dataset.interleave` 转换来并行化数据加载步骤，从而交错其他数据集（如数据文件读取器）的内容。可以通过 `cycle_length` 参数指定想要重叠的数据集数量，并通过 `num_parallel_calls` 参数指定并行度级别。与 `prefetch` 转换类似，`interleave` 转换也支持 `tf.data.experimental.AUTOTUNE`，它将让 `tf.data` 运行时决定要使用的并行度级别。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gs8O8Vbu-jKu"
      },
      "source": [
        "#### 顺序交错\n",
        "\n",
        "`tf.data.Dataset.interleave` 转换的默认参数会使其按顺序交错两个数据集中的单个样本。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fDH12GiK-jKw"
      },
      "outputs": [],
      "source": [
        "benchmark(\n",
        "    tf.data.Dataset.range(2)\n",
        "    .interleave(ArtificialDataset)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "78CsSOnf-jK0"
      },
      "source": [
        "![Sequential interleave](https://tensorflow.google.cn/guide/images/data_performance/sequential_interleave.svg)\n",
        "\n",
        "该图可以展示 `interleave` 转换的行为，从两个可用的数据集中交替预提取样本。但是，这里不涉及性能改进。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j3cqqmYl-jK2"
      },
      "source": [
        "#### 并行交错\n",
        "\n",
        "现在，使用 `interleave` 转换的 `num_parallel_calls` 参数。这样可以并行加载多个数据集，从而减少等待打开文件的时间。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a3FQcTPY-jK4"
      },
      "outputs": [],
      "source": [
        "benchmark(\n",
        "    tf.data.Dataset.range(2)\n",
        "    .interleave(\n",
        "        ArtificialDataset,\n",
        "        num_parallel_calls=tf.data.experimental.AUTOTUNE\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RxRLPB6C-jLA"
      },
      "source": [
        "![Parallel interleave](https://tensorflow.google.cn/guide/images/data_performance/parallel_interleave.svg)\n",
        "\n",
        "这次，两个数据集的读取并行进行，从而减少了全局数据处理时间。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5ZCLFWyv-jLB"
      },
      "source": [
        "### 并行数据转换\n",
        "\n",
        "准备数据时，可能需要对输入元素进行预处理。为此，`tf.data` API 提供了 `tf.data.Dataset.map` 转换，该转换会将用户定义的函数应用于输入数据集的每个元素。由于输入元素彼此独立，可以在多个 CPU 核心之间并行预处理。为了实现这一点，类似 `prefetch` 和 `interleave` 转换，`map` 转换也提供了`num_parallel_calls` 参数来指定并行度级别。\n",
        "\n",
        "`num_parallel_calls` 参数最佳值的选择取决于您的硬件、训练数据的特性（如数据大小和形状）、映射函数的开销，以及同一时间在 CPU 上进行的其他处理。一个简单的试探法是使用可用的 CPU 核心数。但是，对于 `prefetch` 和 `interleave` 转换，`map` 转换支持 `tf.data.experimental.AUTOTUNE`，它将让 `tf.data` 运行时决定要使用的并行度级别。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GSkKetpx-jLD"
      },
      "outputs": [],
      "source": [
        "def mapped_function(s):\n",
        "    # Do some hard pre-processing\n",
        "    tf.py_function(lambda: time.sleep(0.03), [], ())\n",
        "    return s"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wiU7W_QC-jLI"
      },
      "source": [
        "#### 顺序映射\n",
        "\n",
        "首先使用不具有并行度的 `map` 转换作为基准示例。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZSBvDpJG-jLL"
      },
      "outputs": [],
      "source": [
        "benchmark(\n",
        "    ArtificialDataset()\n",
        "    .map(mapped_function)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ngwMTDb6-jLR"
      },
      "source": [
        "![Sequential mapping](https://tensorflow.google.cn/guide/images/data_performance/sequential_map.svg)\n",
        "\n",
        "对于[朴素方法](#The-naive-approach)来说，单次迭代的用时就是花费在打开、读取、预处理（映射）和训练步骤上的时间总和。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U-10PE1D-jLU"
      },
      "source": [
        "#### 并行映射\n",
        "\n",
        "现在，使用相同的预处理函数，但将其并行应用于多个样本。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F8AYLZbg-jLV"
      },
      "outputs": [],
      "source": [
        "benchmark(\n",
        "    ArtificialDataset()\n",
        "    .map(\n",
        "        mapped_function,\n",
        "        num_parallel_calls=tf.data.experimental.AUTOTUNE\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-MoJklzP-jLe"
      },
      "source": [
        "![Parallel mapping](https://tensorflow.google.cn/guide/images/data_performance/parallel_map.svg)\n",
        "\n",
        "现在，您可以在图上看到预处理步骤重叠，从而减少了单次迭代的总时间。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZY1Q9kJO-jLh"
      },
      "source": [
        "### 缓存\n",
        "\n",
        "`tf.data.Dataset.cache` 转换可以在内存中或本地存储空间中缓存数据集。这样可以避免一些运算（如文件打开和数据读取）在每个周期都被执行。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xieLApaI-jLi"
      },
      "outputs": [],
      "source": [
        "benchmark(\n",
        "    ArtificialDataset()\n",
        "    .map(  # Apply time consuming operations before cache\n",
        "        mapped_function\n",
        "    ).cache(\n",
        "    ),\n",
        "    5\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KeMgW9XI-jLn"
      },
      "source": [
        "![Cached dataset](https://tensorflow.google.cn/guide/images/data_performance/cached_dataset.svg)\n",
        "\n",
        "缓存数据集时，仅在第一个周期执行一次 `cache` 之前的转换（如文件打开和数据读取）。后续周期将重用通过 `cache` 转换缓存的数据。\n",
        "\n",
        "如果传递到 `map` 转换的用户定义函数开销很大，可在 `map` 转换后应用 `cache` 转换，只要生成的数据集仍然可以放入内存或本地存储空间即可。如果用户定义函数增加了存储数据集所需的空间（超出缓存容量），在 `cache` 转换后应用该函数，或者考虑在训练作业之前对数据进行预处理以减少资源使用。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i3NtGI3r-jLp"
      },
      "source": [
        "### 向量化映射\n",
        "\n",
        "调用传递给 `map` 转换的用户定义函数会产生与调度和执行用户定义函数相关的开销。我们建议对用户定义函数进行向量化处理（即，让它一次运算一批输入）并在 `map` 转换*之前*应用 `batch` 转换。\n",
        "\n",
        "为了说明这种良好做法，不适合使用您的人工数据集。调度延迟大约为 10 微妙（10e-6 秒），远远小于 `ArtificialDataset` 中使用的数十毫秒，因此很难看出它的影响。\n",
        "\n",
        "在本示例中，我们使用基本 `tf.data.Dataset.range` 函数并将训练循环简化为最简形式。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xqtiYPmb-jLt"
      },
      "outputs": [],
      "source": [
        "fast_dataset = tf.data.Dataset.range(10000)\n",
        "\n",
        "def fast_benchmark(dataset, num_epochs=2):\n",
        "    start_time = time.perf_counter()\n",
        "    for _ in tf.data.Dataset.range(num_epochs):\n",
        "        for _ in dataset:\n",
        "            pass\n",
        "    tf.print(\"Execution time:\", time.perf_counter() - start_time)\n",
        "    \n",
        "def increment(x):\n",
        "    return x+1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fj2gmsMT-jL5"
      },
      "source": [
        "#### 标量映射"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Imn3SslJ-jMA"
      },
      "outputs": [],
      "source": [
        "fast_benchmark(\n",
        "    fast_dataset\n",
        "    # Apply function one item at a time\n",
        "    .map(increment)\n",
        "    # Batch\n",
        "    .batch(256)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BWUNbPqv-jMF"
      },
      "source": [
        "![Scalar map](https://tensorflow.google.cn/guide/images/data_performance/scalar_map.svg)\n",
        "\n",
        "上图说明了正在发生的事情（样本较少）。您可以看到已将映射函数应用于每个样本。虽然此函数速度很快，但会产生一些开销影响时间性能。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tDVSM0A--jMG"
      },
      "source": [
        "#### 向量化映射"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nAw1mDLw-jMI"
      },
      "outputs": [],
      "source": [
        "fast_benchmark(\n",
        "    fast_dataset\n",
        "    .batch(256)\n",
        "    # Apply function on a batch of items\n",
        "    # The tf.Tensor.__add__ method already handle batches\n",
        "    .map(increment)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DbMteMY9-jMO"
      },
      "source": [
        "![Vectorized map](https://tensorflow.google.cn/guide/images/data_performance/vectorized_map.svg)\n",
        "\n",
        "这次，映射函数被调用了一次，并被应用于一批样本。虽然该函数可能需要花费更多时间执行，但开销仅出现了一次，从而改善了整体时间性能。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hfueG0Wj-jMR"
      },
      "source": [
        "### 减少内存占用\n",
        "\n",
        "许多转换（包括 `interleave`、`prefetch` 和 `shuffle`）会维护元素的内部缓冲区。如果传递给 `map` 转换的用户定义函数更改了元素大小，则映射转换和缓冲元素的转换的顺序会影响内存使用量。通常，我们建议选择能够降低内存占用的顺序，除非需要不同的顺序以提高性能。\n",
        "\n",
        "#### 缓存部分计算\n",
        "\n",
        "建议在 `map` 转换后缓存数据集，除非此转换会使数据过大而不适合放在内存中。如果映射函数可以分成两个部分，则能实现折衷：一个耗时的部分和一个消耗内存的部分。在这种情况下，您可以按如下方式将转换链接起来：\n",
        "\n",
        "```python\n",
        "dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)\n",
        "```\n",
        "\n",
        "这样，耗时部分仅在第一个周期执行，从而避免了使用过多缓存空间。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MYOHG69M-jMT"
      },
      "source": [
        "## 最佳做法总结\n",
        "\n",
        "以下是设计高效 TensorFlow 输入流水线的最佳做法总结：\n",
        "\n",
        "- [使用 `prefetch` 转换](#Pipelining)使生产者和使用者的工作重叠。\n",
        "- 使用 `interleave` 转换实现[并行数据读取转换](#Parallelizing-data-extraction)。\n",
        "- 通过设置 `num_parallel_calls` 参数实现[并行 `map` 转换](#Parallelizing-data-transformation)。\n",
        "- 第一个周期[使用 `cache` 转换](#Caching)将数据缓存在内存中。\n",
        "- [向量化](#Map-and-batch)传递给 `map` 转换的用户定义函数\n",
        "- 应用 `interleave`、`prefetch` 和 `shuffle` 转换时[减少内存使用量](#Reducing-memory-footprint)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mP_EMFsQ-jMU"
      },
      "source": [
        "## 重现图表\n",
        "\n",
        "注：本笔记本的剩余部分是关于如何重现上述图表的，请随意使用以下代码，但请了解其并非本教程的主要内容。\n",
        "\n",
        "要更深入地了解 `tf.data.Dataset` API，您可以使用自己的流水线。下面是用来绘制本指南中图像的代码。这可以作为一个好的起点，展示了一些常见困难的解决方法，例如：\n",
        "\n",
        "- 执行时间的可重现性；\n",
        "- 映射函数的 Eager Execution；\n",
        "- `interleave` 转换的可调用对象。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7M_jFLer-jMV"
      },
      "outputs": [],
      "source": [
        "import itertools\n",
        "from collections import defaultdict\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z3pjnxtK-jMa"
      },
      "source": [
        "### 数据集\n",
        "\n",
        "与 `ArtificialDataset` 类似，您可以构建一个返回每步用时的数据集。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OgGl4U7t-jMc"
      },
      "outputs": [],
      "source": [
        "class TimeMeasuredDataset(tf.data.Dataset):\n",
        "    # OUTPUT: (steps, timings, counters)\n",
        "    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)\n",
        "    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))\n",
        "    \n",
        "    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated\n",
        "    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset\n",
        "    \n",
        "    def _generator(instance_idx, num_samples):\n",
        "        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])\n",
        "        \n",
        "        # Opening the file\n",
        "        open_enter = time.perf_counter()\n",
        "        time.sleep(0.03)\n",
        "        open_elapsed = time.perf_counter() - open_enter\n",
        "        \n",
        "        for sample_idx in range(num_samples):\n",
        "            # Reading data (line, record) from the file\n",
        "            read_enter = time.perf_counter()\n",
        "            time.sleep(0.015)\n",
        "            read_elapsed = time.perf_counter() - read_enter\n",
        "            \n",
        "            yield (\n",
        "                [(\"Open\",), (\"Read\",)],\n",
        "                [(open_enter, open_elapsed), (read_enter, read_elapsed)],\n",
        "                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]\n",
        "            )\n",
        "            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered\n",
        "            \n",
        "    \n",
        "    def __new__(cls, num_samples=3):\n",
        "        return tf.data.Dataset.from_generator(\n",
        "            cls._generator,\n",
        "            output_types=cls.OUTPUT_TYPES,\n",
        "            output_shapes=cls.OUTPUT_SHAPES,\n",
        "            args=(next(cls._INSTANCES_COUNTER), num_samples)\n",
        "        )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YQqDP4jk-jMj"
      },
      "source": [
        "此数据集会提供形状为 `[[2, 1], [2, 2], [2, 3]]` 且类型为 `[tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32]` 的样本。每个样本为：\n",
        "\n",
        "```\n",
        "(   [(\"Open\"), (\"Read\")],   [(t0, d), (t0, d)],   [(i, e, -1), (i, e, s)] )\n",
        "```\n",
        "\n",
        "其中：\n",
        "\n",
        "- `Open` 和 `Read` 是步骤标识符\n",
        "- `t0` 是相应步骤开始时的时间戳\n",
        "- `d` 是在相应步骤中花费的时间\n",
        "- `i` 是实例索引\n",
        "- `e` 是周期索引（数据集被迭代的次数）\n",
        "- `s` 是样本索引"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IQK913bB-jMm"
      },
      "source": [
        "### 迭代循环\n",
        "\n",
        "使迭代循环稍微复杂一点，以汇总所有计时。这仅适用于生成上述样本的数据集。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zAy-K_Cq-jMn"
      },
      "outputs": [],
      "source": [
        "def timelined_benchmark(dataset, num_epochs=2):\n",
        "    # Initialize accumulators\n",
        "    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)\n",
        "    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)\n",
        "    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)\n",
        "    \n",
        "    start_time = time.perf_counter()\n",
        "    for epoch_num in range(num_epochs):\n",
        "        epoch_enter = time.perf_counter()\n",
        "        for (steps, times, values) in dataset:\n",
        "            # Record dataset preparation informations\n",
        "            steps_acc = tf.concat((steps_acc, steps), axis=0)\n",
        "            times_acc = tf.concat((times_acc, times), axis=0)\n",
        "            values_acc = tf.concat((values_acc, values), axis=0)\n",
        "            \n",
        "            # Simulate training time\n",
        "            train_enter = time.perf_counter()\n",
        "            time.sleep(0.01)\n",
        "            train_elapsed = time.perf_counter() - train_enter\n",
        "            \n",
        "            # Record training informations\n",
        "            steps_acc = tf.concat((steps_acc, [[\"Train\"]]), axis=0)\n",
        "            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)\n",
        "            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)\n",
        "        \n",
        "        epoch_elapsed = time.perf_counter() - epoch_enter\n",
        "        # Record epoch informations\n",
        "        steps_acc = tf.concat((steps_acc, [[\"Epoch\"]]), axis=0)\n",
        "        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)\n",
        "        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)\n",
        "        time.sleep(0.001)\n",
        "    \n",
        "    tf.print(\"Execution time:\", time.perf_counter() - start_time)\n",
        "    return {\"steps\": steps_acc, \"times\": times_acc, \"values\": values_acc}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jw_WSQC8-jMs"
      },
      "source": [
        "### 绘图方法\n",
        "\n",
        "最后，定义一个函数，根据 `timelined_benchmark` 函数返回的值绘制时间线。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1j73RxiP-jMw"
      },
      "outputs": [],
      "source": [
        "def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):\n",
        "    # Remove invalid entries (negative times, or empty steps) from the timelines\n",
        "    invalid_mask = np.logical_and(timeline['times'] &gt; 0, timeline['steps'] != b'')[:,0]\n",
        "    steps = timeline['steps'][invalid_mask].numpy()\n",
        "    times = timeline['times'][invalid_mask].numpy()\n",
        "    values = timeline['values'][invalid_mask].numpy()\n",
        "    \n",
        "    # Get a set of different steps, ordered by the first time they are encountered\n",
        "    step_ids, indices = np.stack(np.unique(steps, return_index=True))\n",
        "    step_ids = step_ids[np.argsort(indices)]\n",
        "\n",
        "    # Shift the starting time to 0 and compute the maximal time value\n",
        "    min_time = times[:,0].min()\n",
        "    times[:,0] = (times[:,0] - min_time)\n",
        "    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)\n",
        "    \n",
        "    cmap = mpl.cm.get_cmap(\"plasma\")\n",
        "    plt.close()\n",
        "    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})\n",
        "    fig.suptitle(title)\n",
        "    fig.set_size_inches(17.0, len(step_ids))\n",
        "    plt.xlim(-0.01, end)\n",
        "    \n",
        "    for i, step in enumerate(step_ids):\n",
        "        step_name = step.decode()\n",
        "        ax = axs[i]\n",
        "        ax.set_ylabel(step_name)\n",
        "        ax.set_ylim(0, 1)\n",
        "        ax.set_yticks([])\n",
        "        ax.set_xlabel(\"time (s)\")\n",
        "        ax.set_xticklabels([])\n",
        "        ax.grid(which=\"both\", axis=\"x\", color=\"k\", linestyle=\":\")\n",
        "        \n",
        "        # Get timings and annotation for the given step\n",
        "        entries_mask = np.squeeze(steps==step)\n",
        "        serie = np.unique(times[entries_mask], axis=0)\n",
        "        annotations = values[entries_mask]\n",
        "        \n",
        "        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)\n",
        "        if annotate:\n",
        "            for j, (start, width) in enumerate(serie):\n",
        "                annotation = \"\\n\".join([f\"{l}: {v}\" for l,v in zip((\"i\", \"e\", \"s\"), annotations[j])])\n",
        "                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,\n",
        "                        horizontalalignment='left', verticalalignment='center')\n",
        "    if save:\n",
        "        plt.savefig(title.lower().translate(str.maketrans(\" \", \"_\")) + \".svg\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xto6GNdO-jM1"
      },
      "source": [
        "### 对映射函数使用封装容器\n",
        "\n",
        "要在 Eager 上下文中运行映射函数，必须将其封装在 `tf.py_function` 调用中。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "39v7JD4L-jM2"
      },
      "outputs": [],
      "source": [
        "def map_decorator(func):\n",
        "    def wrapper(steps, times, values):\n",
        "        # Use a tf.py_function to prevent auto-graph from compiling the method\n",
        "        return tf.py_function(\n",
        "            func,\n",
        "            inp=(steps, times, values),\n",
        "            Tout=(steps.dtype, times.dtype, values.dtype)\n",
        "        )\n",
        "    return wrapper"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7eJRCinb-jM5"
      },
      "source": [
        "### 流水线对比"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YwX4ndHE-jM6"
      },
      "outputs": [],
      "source": [
        "_batch_map_num_items = 50\n",
        "\n",
        "def dataset_generator_fun(*args):\n",
        "    return TimeMeasuredDataset(num_samples=_batch_map_num_items)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EwxJT2aR-jNA"
      },
      "source": [
        "#### 朴素流水线"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wLKgurx_-jNC"
      },
      "outputs": [],
      "source": [
        "@map_decorator\n",
        "def naive_map(steps, times, values):\n",
        "    map_enter = time.perf_counter()\n",
        "    time.sleep(0.001)  # Time consuming step\n",
        "    time.sleep(0.0001)  # Memory consuming step\n",
        "    map_elapsed = time.perf_counter() - map_enter\n",
        "\n",
        "    return (\n",
        "        tf.concat((steps, [[\"Map\"]]), axis=0),\n",
        "        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),\n",
        "        tf.concat((values, [values[-1]]), axis=0)\n",
        "    )\n",
        "\n",
        "naive_timeline = timelined_benchmark(\n",
        "    tf.data.Dataset.range(2)\n",
        "    .flat_map(dataset_generator_fun)\n",
        "    .map(naive_map)\n",
        "    .batch(_batch_map_num_items, drop_remainder=True)\n",
        "    .unbatch(),\n",
        "    5\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EJqUMDsO-jNG"
      },
      "source": [
        "### 优化后的流水线"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HYHcwabr-jNH"
      },
      "outputs": [],
      "source": [
        "@map_decorator\n",
        "def time_consuming_map(steps, times, values):\n",
        "    map_enter = time.perf_counter()\n",
        "    time.sleep(0.001 * values.shape[0])  # Time consuming step\n",
        "    map_elapsed = time.perf_counter() - map_enter\n",
        "\n",
        "    return (\n",
        "        tf.concat((steps, tf.tile([[[\"1st map\"]]], [steps.shape[0], 1, 1])), axis=1),\n",
        "        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),\n",
        "        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)\n",
        "    )\n",
        "\n",
        "\n",
        "@map_decorator\n",
        "def memory_consuming_map(steps, times, values):\n",
        "    map_enter = time.perf_counter()\n",
        "    time.sleep(0.0001 * values.shape[0])  # Memory consuming step\n",
        "    map_elapsed = time.perf_counter() - map_enter\n",
        "\n",
        "    # Use tf.tile to handle batch dimension\n",
        "    return (\n",
        "        tf.concat((steps, tf.tile([[[\"2nd map\"]]], [steps.shape[0], 1, 1])), axis=1),\n",
        "        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),\n",
        "        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)\n",
        "    )\n",
        "\n",
        "\n",
        "optimized_timeline = timelined_benchmark(\n",
        "    tf.data.Dataset.range(2)\n",
        "    .interleave(  # Parallelize data reading\n",
        "        dataset_generator_fun,\n",
        "        num_parallel_calls=tf.data.experimental.AUTOTUNE\n",
        "    )\n",
        "    .batch(  # Vectorize your mapped function\n",
        "        _batch_map_num_items,\n",
        "        drop_remainder=True)\n",
        "    .map(  # Parallelize map transformation\n",
        "        time_consuming_map,\n",
        "        num_parallel_calls=tf.data.experimental.AUTOTUNE\n",
        "    )\n",
        "    .cache()  # Cache data\n",
        "    .map(  # Reduce memory usage\n",
        "        memory_consuming_map,\n",
        "        num_parallel_calls=tf.data.experimental.AUTOTUNE\n",
        "    )\n",
        "    .prefetch(  # Overlap producer and consumer works\n",
        "        tf.data.experimental.AUTOTUNE\n",
        "    )\n",
        "    .unbatch(),\n",
        "    5\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b_CSUbxL-jNK"
      },
      "outputs": [],
      "source": [
        "draw_timeline(naive_timeline, \"Naive\", 15)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DoovY7qr-jNR"
      },
      "outputs": [],
      "source": [
        "draw_timeline(optimized_timeline, \"Optimized\", 15)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "data_performance.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
